{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import numpy.linalg as npl\n",
    "import scipy.linalg as spl\n",
    "import sklearn.metrics as sklm\n",
    "import math\n",
    "import matplotlib.pyplot as plt\n",
    "plt.rcParams['figure.figsize'] = [6, 6]\n",
    "from IPython.core.interactiveshell import InteractiveShell\n",
    "InteractiveShell.ast_node_interactivity = \"all\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "irisD = {'virginica': np.array([[6.3, 3.3, 6. , 2.5],\n",
    "       [5.8, 2.7, 5.1, 1.9],\n",
    "       [7.1, 3. , 5.9, 2.1],\n",
    "       [6.3, 2.9, 5.6, 1.8],\n",
    "       [6.5, 3. , 5.8, 2.2],\n",
    "       [7.6, 3. , 6.6, 2.1],\n",
    "       [4.9, 2.5, 4.5, 1.7],\n",
    "       [7.3, 2.9, 6.3, 1.8],\n",
    "       [6.7, 2.5, 5.8, 1.8],\n",
    "       [7.2, 3.6, 6.1, 2.5],\n",
    "       [6.5, 3.2, 5.1, 2. ],\n",
    "       [6.4, 2.7, 5.3, 1.9],\n",
    "       [6.8, 3. , 5.5, 2.1],\n",
    "       [5.7, 2.5, 5. , 2. ],\n",
    "       [5.8, 2.8, 5.1, 2.4],\n",
    "       [6.4, 3.2, 5.3, 2.3],\n",
    "       [6.5, 3. , 5.5, 1.8],\n",
    "       [7.7, 3.8, 6.7, 2.2],\n",
    "       [7.7, 2.6, 6.9, 2.3],\n",
    "       [6. , 2.2, 5. , 1.5],\n",
    "       [6.9, 3.2, 5.7, 2.3],\n",
    "       [5.6, 2.8, 4.9, 2. ],\n",
    "       [7.7, 2.8, 6.7, 2. ],\n",
    "       [6.3, 2.7, 4.9, 1.8],\n",
    "       [6.7, 3.3, 5.7, 2.1],\n",
    "       [7.2, 3.2, 6. , 1.8],\n",
    "       [6.2, 2.8, 4.8, 1.8],\n",
    "       [6.1, 3. , 4.9, 1.8],\n",
    "       [6.4, 2.8, 5.6, 2.1],\n",
    "       [7.2, 3. , 5.8, 1.6],\n",
    "       [7.4, 2.8, 6.1, 1.9],\n",
    "       [7.9, 3.8, 6.4, 2. ],\n",
    "       [6.4, 2.8, 5.6, 2.2],\n",
    "       [6.3, 2.8, 5.1, 1.5],\n",
    "       [6.1, 2.6, 5.6, 1.4],\n",
    "       [7.7, 3. , 6.1, 2.3],\n",
    "       [6.3, 3.4, 5.6, 2.4],\n",
    "       [6.4, 3.1, 5.5, 1.8],\n",
    "       [6. , 3. , 4.8, 1.8],\n",
    "       [6.9, 3.1, 5.4, 2.1],\n",
    "       [6.7, 3.1, 5.6, 2.4],\n",
    "       [6.9, 3.1, 5.1, 2.3],\n",
    "       [5.8, 2.7, 5.1, 1.9],\n",
    "       [6.8, 3.2, 5.9, 2.3],\n",
    "       [6.7, 3.3, 5.7, 2.5],\n",
    "       [6.7, 3. , 5.2, 2.3],\n",
    "       [6.3, 2.5, 5. , 1.9],\n",
    "       [6.5, 3. , 5.2, 2. ],\n",
    "       [6.2, 3.4, 5.4, 2.3],\n",
    "       [5.9, 3. , 5.1, 1.8]]), 'setosa': np.array([[5.1, 3.5, 1.4, 0.2],\n",
    "       [4.9, 3. , 1.4, 0.2],\n",
    "       [4.7, 3.2, 1.3, 0.2],\n",
    "       [4.6, 3.1, 1.5, 0.2],\n",
    "       [5. , 3.6, 1.4, 0.2],\n",
    "       [5.4, 3.9, 1.7, 0.4],\n",
    "       [4.6, 3.4, 1.4, 0.3],\n",
    "       [5. , 3.4, 1.5, 0.2],\n",
    "       [4.4, 2.9, 1.4, 0.2],\n",
    "       [4.9, 3.1, 1.5, 0.1],\n",
    "       [5.4, 3.7, 1.5, 0.2],\n",
    "       [4.8, 3.4, 1.6, 0.2],\n",
    "       [4.8, 3. , 1.4, 0.1],\n",
    "       [4.3, 3. , 1.1, 0.1],\n",
    "       [5.8, 4. , 1.2, 0.2],\n",
    "       [5.7, 4.4, 1.5, 0.4],\n",
    "       [5.4, 3.9, 1.3, 0.4],\n",
    "       [5.1, 3.5, 1.4, 0.3],\n",
    "       [5.7, 3.8, 1.7, 0.3],\n",
    "       [5.1, 3.8, 1.5, 0.3],\n",
    "       [5.4, 3.4, 1.7, 0.2],\n",
    "       [5.1, 3.7, 1.5, 0.4],\n",
    "       [4.6, 3.6, 1. , 0.2],\n",
    "       [5.1, 3.3, 1.7, 0.5],\n",
    "       [4.8, 3.4, 1.9, 0.2],\n",
    "       [5. , 3. , 1.6, 0.2],\n",
    "       [5. , 3.4, 1.6, 0.4],\n",
    "       [5.2, 3.5, 1.5, 0.2],\n",
    "       [5.2, 3.4, 1.4, 0.2],\n",
    "       [4.7, 3.2, 1.6, 0.2],\n",
    "       [4.8, 3.1, 1.6, 0.2],\n",
    "       [5.4, 3.4, 1.5, 0.4],\n",
    "       [5.2, 4.1, 1.5, 0.1],\n",
    "       [5.5, 4.2, 1.4, 0.2],\n",
    "       [4.9, 3.1, 1.5, 0.2],\n",
    "       [5. , 3.2, 1.2, 0.2],\n",
    "       [5.5, 3.5, 1.3, 0.2],\n",
    "       [4.9, 3.6, 1.4, 0.1],\n",
    "       [4.4, 3. , 1.3, 0.2],\n",
    "       [5.1, 3.4, 1.5, 0.2],\n",
    "       [5. , 3.5, 1.3, 0.3],\n",
    "       [4.5, 2.3, 1.3, 0.3],\n",
    "       [4.4, 3.2, 1.3, 0.2],\n",
    "       [5. , 3.5, 1.6, 0.6],\n",
    "       [5.1, 3.8, 1.9, 0.4],\n",
    "       [4.8, 3. , 1.4, 0.3],\n",
    "       [5.1, 3.8, 1.6, 0.2],\n",
    "       [4.6, 3.2, 1.4, 0.2],\n",
    "       [5.3, 3.7, 1.5, 0.2],\n",
    "       [5. , 3.3, 1.4, 0.2]]), 'versicolor': np.array([[7. , 3.2, 4.7, 1.4],\n",
    "       [6.4, 3.2, 4.5, 1.5],\n",
    "       [6.9, 3.1, 4.9, 1.5],\n",
    "       [5.5, 2.3, 4. , 1.3],\n",
    "       [6.5, 2.8, 4.6, 1.5],\n",
    "       [5.7, 2.8, 4.5, 1.3],\n",
    "       [6.3, 3.3, 4.7, 1.6],\n",
    "       [4.9, 2.4, 3.3, 1. ],\n",
    "       [6.6, 2.9, 4.6, 1.3],\n",
    "       [5.2, 2.7, 3.9, 1.4],\n",
    "       [5. , 2. , 3.5, 1. ],\n",
    "       [5.9, 3. , 4.2, 1.5],\n",
    "       [6. , 2.2, 4. , 1. ],\n",
    "       [6.1, 2.9, 4.7, 1.4],\n",
    "       [5.6, 2.9, 3.6, 1.3],\n",
    "       [6.7, 3.1, 4.4, 1.4],\n",
    "       [5.6, 3. , 4.5, 1.5],\n",
    "       [5.8, 2.7, 4.1, 1. ],\n",
    "       [6.2, 2.2, 4.5, 1.5],\n",
    "       [5.6, 2.5, 3.9, 1.1],\n",
    "       [5.9, 3.2, 4.8, 1.8],\n",
    "       [6.1, 2.8, 4. , 1.3],\n",
    "       [6.3, 2.5, 4.9, 1.5],\n",
    "       [6.1, 2.8, 4.7, 1.2],\n",
    "       [6.4, 2.9, 4.3, 1.3],\n",
    "       [6.6, 3. , 4.4, 1.4],\n",
    "       [6.8, 2.8, 4.8, 1.4],\n",
    "       [6.7, 3. , 5. , 1.7],\n",
    "       [6. , 2.9, 4.5, 1.5],\n",
    "       [5.7, 2.6, 3.5, 1. ],\n",
    "       [5.5, 2.4, 3.8, 1.1],\n",
    "       [5.5, 2.4, 3.7, 1. ],\n",
    "       [5.8, 2.7, 3.9, 1.2],\n",
    "       [6. , 2.7, 5.1, 1.6],\n",
    "       [5.4, 3. , 4.5, 1.5],\n",
    "       [6. , 3.4, 4.5, 1.6],\n",
    "       [6.7, 3.1, 4.7, 1.5],\n",
    "       [6.3, 2.3, 4.4, 1.3],\n",
    "       [5.6, 3. , 4.1, 1.3],\n",
    "       [5.5, 2.5, 4. , 1.3],\n",
    "       [5.5, 2.6, 4.4, 1.2],\n",
    "       [6.1, 3. , 4.6, 1.4],\n",
    "       [5.8, 2.6, 4. , 1.2],\n",
    "       [5. , 2.3, 3.3, 1. ],\n",
    "       [5.6, 2.7, 4.2, 1.3],\n",
    "       [5.7, 3. , 4.2, 1.2],\n",
    "       [5.7, 2.9, 4.2, 1.3],\n",
    "       [6.2, 2.9, 4.3, 1.3],\n",
    "       [5.1, 2.5, 3. , 1.1],\n",
    "       [5.7, 2.8, 4.1, 1.3]])}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 14.1 Classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "array([ 1, -1,  1])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf2pm1 = lambda b:2*b-1 #convert T/F to 1 -1\n",
    "b=True\n",
    "tf2pm1(b)\n",
    "b = np.array([True,False,True])\n",
    "tf2pm1(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "numTP = lambda y,yhat: sum([1 for i in range(len(y)) if y[i] == True and yhat[i] == True])\n",
    "numFN = lambda y,yhat: sum([1 for i in range(len(y)) if y[i] == True and yhat[i] == False])\n",
    "numFP = lambda y,yhat: sum([1 for i in range(len(y)) if y[i] == False and yhat[i] == True])\n",
    "numTN = lambda y,yhat: sum([1 for i in range(len(y)) if y[i] == False and yhat[i] == False]) \n",
    "confusion_matrix = lambda y,yhat: np.vstack([[numTP(y,yhat),numFN(y,yhat)],[numFP(y,yhat),numTN(y,yhat)]])\n",
    "error_rate = lambda y,yhat: (numFN(y,yhat) + numFP(y,yhat)) / len(y)\n",
    "error_rate2 = lambda y,yhat: np.average(y != yhat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# y,yhat = np.random.choice(a = [True, False], size = (100)),np.random.choice(a = [True, False], size = (100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "#pseudo-seed to test between skl and Julia\n",
    "y = np.array([False,  True, False,  True, False,  True, False,  True,  True,\n",
    "        True,  True,  True, False,  True,  True,  True,  True, False,\n",
    "       False,  True,  True,  True,  True,  True,  True, False, False,\n",
    "       False,  True,  True,  True,  True,  True, False,  True,  True,\n",
    "        True,  True,  True,  True,  True,  True,  True,  True, False,\n",
    "        True,  True, False, False,  True,  True,  True,  True, False,\n",
    "        True, False,  True, False,  True, False,  True,  True,  True,\n",
    "        True, False,  True,  True, False, False, False, False,  True,\n",
    "        True,  True, False, False,  True, False,  True,  True, False,\n",
    "        True, False,  True, False, False, False,  True, False, False,\n",
    "        True,  True, False, False,  True,  True, False, False,  True,\n",
    "       False])\n",
    "yhat = np.array([False,  True,  True,  True,  True,  True, False, False, False,\n",
    "       False, False, False, False,  True,  True,  True, False, False,\n",
    "        True,  True, False,  True,  True, False,  True, False,  True,\n",
    "       False,  True,  True, False,  True, False, False,  True, False,\n",
    "       False, False, False,  True,  True, False, False,  True, False,\n",
    "        True,  True, False, False,  True, False, False,  True,  True,\n",
    "        True,  True,  True, False, False, False,  True, False,  True,\n",
    "       False, False, False, False, False, False,  True, False,  True,\n",
    "       False,  True, False,  True, False, False, False, False, False,\n",
    "       False, False,  True,  True, False,  True, False,  True,  True,\n",
    "       False,  True,  True, False,  True,  True, False,  True,  True,\n",
    "        True])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[32, 30],\n",
       "       [15, 23]])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "array([[32, 30],\n",
       "       [15, 23]])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "0.45"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "0.45"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "confusion_matrix(y,yhat)\n",
    "np.flip(sklm.confusion_matrix(y,yhat)) #sklm implementation is flipped relative to VMLS\n",
    "error_rate(y,yhat)\n",
    "error_rate2(y,yhat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.axes._subplots.AxesSubplot at 0x1a1c180f50>"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAFqCAYAAAAk+oAqAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAATHklEQVR4nO3de5DdZX3H8feXXAgKGpSLGILg7eA9FEQGitKUQtSOVqsoWKpcXItllKqjVovo1KmKAqWtqEvDxSnFGzhFVGyodBAv4RKDgOGIImoAExBQouV6vv1jD3ZJs3s22ew+eZ59v5jfzJ7f+Z3feVzPfPab7/P8ficyE0nS9Nuq9AAkaaYygCWpEANYkgoxgCWpEANYkgoxgCWpkNmlB9CqTqczD7gc2JqR3/OXut3uSZ1O5zxgH+BB4ErgLd1u98FyI9V0GudzsQfwOeAJwArgyG63+0C5kWo6WAFPnfuBxd1u9wXAImBJp9PZDzgP2BN4HrANcGy5IaqAsT4XHwNO63a7zwDuBo4pOEZNk4EVcETsCbwSWAAkcBtwUWaumuKxVa3b7Sawrv9wTn/Lbrf7tUeO6XQ6VwK7FhieChnrcwEsBo7o7z8X+CDwqeken6bXuBVwRLyHkX8WBSP/XL6q//P5EfHeqR9e3TqdzqxOp7MSWAss63a7y0c9Nwc4Erik1PhUxvqfC+AnwD3dbveh/iGrGSl41LhBFfAxwHMy81E9yog4FbgB+OhUDawF3W73YWBRp9OZD3y50+k8t9vtXt9/+gzg8m63+61yI1QJ638ugGdt4DDvETADxHj3goiIG4FDM/Nn6+1/CvCfmdkZ43VDwBDAGad8eO9j//LwzTfiSp1x1nlsM29rjjriNZxx1nnc+KOf8I//8HdstdXMbMM/dN03Sw9hi/Dpi7/NvLlzOOcbV3LpyW9l9qytuPbmW/n0xd/hU297benhTbtt/ujYmOw5Hrzz5kn98Zqzw1MnPYaJGlQBnwD8V0TcBPyiv2834OnA8WO9KDOHgWGY/C+jVnfdfQ+zZ8/mcdtty33338/3rvo+R//Fa/nSRZfw7eXXsPSfPjJjw3cmu+ve3zF71lY87jHzuO+BB1l+48846pAXsU9nIZeu6LLkhc/iK9+9gYOe//TSQ61X7+HSI5iwcQM4My+JiGcC+zLSkwpG+lNXZWY9/ysLuONXd/P+D3+Ch3s9spccuvhADjrgRbzgxS9nl5134g1D7wDg4Jfsz3FHv6HwaDVd7vz1Ok489+v0ej16CYfs3eHFz38aT93libznX7/CJy+6gs7CnXjVAc8rPVRNg3FbEJvDTK2ANT5bENqQzdKCWNOdXAti584W04KQpLr0eqVHMGEGsKSmZNYTwM4CSVIhVsCS2mILQpIKqagFYQBLaksr64AlqToVVcBOwklSIVbAktriJJwklVHTOmADWFJbrIAlqZCKKmAn4SSpECtgSW1xHbAkFVJRC8IAltSWiibh7AFLUiFWwJLaYgtCkgqpqAVhAEtqSk3fF2wAS2pLRS0IJ+EkqRArYEltsQcsSYVU1IIwgCW1xUuRJamQiipgJ+EkqRArYEltcRJOkgqpqAVhAEtqS0UVsD1gSSrEAJbUll5vctsAETEvIq6MiGsj4oaI+FB//x4RsTwiboqIz0fE3EHnMoAlNSXz4UltE3A/sDgzXwAsApZExH7Ax4DTMvMZwN3AMYNOZABLassUV8A5Yl3/4Zz+lsBi4Ev9/ecCfzboXAawpLZkb1JbRAxFxNWjtqH13yIiZkXESmAtsAz4CXBPZj7UP2Q1sGDQUF0FIUmjZOYwMDzgmIeBRRExH/gy8KwNHTbovQxgSW2ZxmVomXlPRPw3sB8wPyJm96vgXYHbBr3eFoSktkyyBTFIROzYr3yJiG2Ag4FVwGXAa/qHvRH4j0HnsgKW1Japr4B3Ac6NiFmMFLFfyMyLI+KHwOci4sPA94Glg05kAEtqyxRfipyZPwD22sD+m4F9N+ZctiAkqRArYEltqeheEAawpLYYwJJUSEW3o7QHLEmFWAFLaostCEkqpKIWhAEsqS1WwJJUSEUVsJNwklSIFbCkttiCkKRCDGBJKiQH3gd9i2EAS2pLRRWwk3CSVIgVsKS2VFQBG8CS2lLROmADWFJbKqqA7QFLUiFWwJLa4jI0SSqkohaEASypLQawJBVS0SoIJ+EkqRArYElNyZ6TcJJUhj1gSSqkoh6wASypLRW1IJyEk6RCrIAltcUesCQVYgBLUiEV3QvCHrAkFWIFLKkttiAkqZCKlqEZwJLa4oUYklRIRRWwk3CSVIgVsKSmpJNwklRIRS0IA1hSWyqahLMHLEmFWAFLaostCEkqxEk4SSrECliSCnESTpI0iBWwpLbYgpCkMqb6SriIWAh8FngS0AOGM/P0iPg80OkfNh+4JzMXjXcuA1hSW6a+An4IeGdmroiI7YBrImJZZr7ukQMi4hTg14NOZABLassUB3Bm3g7c3v/53ohYBSwAfggQEQEcBiwedC4n4SRpE0XE7sBewPJRuw8E1mTmTYNebwUsqS2TXIYWEUPA0Khdw5k5vIHjtgUuAE7IzN+Meupw4PyJvJcBLKktk2xB9MP2/wXuaBExh5HwPS8zLxy1fzbwamDvibyXASypKTnFPeB+j3cpsCozT13v6YOBGzNz9UTOZQ9YkjbOAcCRwOKIWNnfXtZ/7vVMsP0AVsCSWjP1qyCuAGKM5960MecygCW1xbuhSVIhXoosSYVUFMBOwklSIVbAkpqSWU8FbABLaktFLQgDWFJbDGBJKmOqr4TbnJyEk6RCrIAltaWiCtgAltSWei6EM4AltcUesCRpICtgSW2pqAI2gCW1xR6wJJVRUw/YAJbUlooqYCfhJKkQK2BJTbEFIUmlVNSCMIAlNSUNYEkqpKIAdhJOkgqxApbUFFsQklSKASxJZdRUAdsDlqRCrIAlNaWmCtgAltQUA1iSSskoPYIJM4AlNaWmCthJOEkqxApYUlOyZwtCkoqoqQVhAEtqSjoJJ0ll1FQBOwknSYVYAUtqipNwklRI1vOVcAawpLbUVAHbA5akQqyAJTWlpgrYAJbUFHvAklSIFbAkFVLTlXBOwklSIVbAkppS06XIBrCkpvQqakEYwJKaYg9YkgrJXkxqGyQiFkbEZRGxKiJuiIi3r/f8uyIiI2KHQeeyApakjfMQ8M7MXBER2wHXRMSyzPxhRCwE/gT4+UROZAUsqSmZk9sGnz9vz8wV/Z/vBVYBC/pPnwa8G5jQ5SBWwJKaMtkLMSJiCBgatWs4M4fHOHZ3YC9geUS8Arg1M6+NmNgYDGBJTZnsKoh+2G4wcEeLiG2BC4ATGGlLvB84ZGPeyxaEJG2kiJjDSPiel5kXAk8D9gCujYhbgF2BFRHxpPHOYwUsqSlTvQwtRvoLS4FVmXnqyHvmdcBOo465BdgnM+8c71xWwJKaMtWTcMABwJHA4ohY2d9etiljtQKW1JSpvhIuM68Axn2TzNx9IucygCU1xSvhJEkDWQFLaorfiDHKNk8+cKrfQhW6dPv9Sw9BW6CD1hw76XN4NzRJKqSmHrABLKkpNVXATsJJUiFWwJKaUtEcnAEsqS01tSAMYElNqWkSzh6wJBViBSypKRV9K70BLKktOf59crYoBrCkpvQqWgZhAEtqSq+iCthJOEkqxApYUlPsAUtSIa6CkKRCaqqA7QFLUiFWwJKaYgtCkgoxgCWpkJp6wAawpKb06slfJ+EkqRQrYElNqelSZANYUlMquhePASypLa6CkKRCelFPC8JJOEkqxApYUlPsAUtSIfaAJakQL8SQJA1kBSypKV6IIUmFOAknSYXU1AM2gCU1paZVEE7CSVIhVsCSmmIPWJIKsQcsSYXU1AM2gCU1paYAdhJOkgqxApbUlLQHLEll1NSCMIAlNaWmALYHLEmFGMCSmpKT3AaJiIURcVlErIqIGyLi7f39r+0/7kXEPhMZqy0ISU2ZhgsxHgLemZkrImI74JqIWAZcD7wa+MxET2QAS2rKVPeAM/N24Pb+z/dGxCpgQWYuA4iN+FZmA1hSU6ZzEi4idgf2ApZvyuvtAUvSKBExFBFXj9qGxjhuW+AC4ITM/M2mvJcVsKSmTPZuaJk5DAyPd0xEzGEkfM/LzAs39b0MYElNmepJuBhp8i4FVmXmqZM5lwEsqSnT0AM+ADgSuC4iVvb3vQ/YGvhnYEfgqxGxMjMPHe9EBrCkpkz1Ddkz8woY86uXv7wx53ISTpIKsQKW1JReRV9KZABLakpNN+MxgCU1pZ761x6wJBVjBSypKbYgJKkQv5ZekgpxFYQkFVJP/DoJJ0nFWAFLaoqTcJJUiD1gSSqknvg1gCU1pqYWhJNwklSIFbCkptgDlqRC6olfA1hSY+wBS5IGsgKW1JSsqAlhAEtqSk0tCANYUlNcBSFJhdQTv07CSVIxVsCSmmILQpIKcRJOkgpxGZokFVJTBewknCQVYgUsqSm2ICSpkJpaEAawpKb0sp4K2B6wJBViBSypKfXUvwawpMZ4JZwkFeIqCEkqpKZVEE7CSVIhVsCSmmIPWJIKsQcsSYXU1AM2gCU1Jb0STpI0iBWwpKY4CSdJhdgDlqRCaloFYQ9YkgqxApbUFHvAklSIy9AkqZDeJLdBIuKsiFgbEdeP2rcoIr4XESsj4uqI2HciYzWAJTUlJ/nfBJwDLFlv38nAhzJzEfCB/uOBDGBJ2giZeTlw1/q7gcf1f348cNtEzmUPWFJTJjsJFxFDwNCoXcOZOTzgZScA34iITzBS2O4/kfcygCU1ZbKTcP2wHRS46zsO+JvMvCAiDgOWAgcPepEtCElN6ZGT2jbRG4EL+z9/EXASTpKmyW3AS/o/LwZumsiLbEFIaspUX4ocEecDBwE7RMRq4CTgzcDpETEbuI9H95DHZABLakpvii/EyMzDx3hq7409lwEsqSn1XAdnAEtqTE33gnASTpIKsQKW1JSaKmADWFJTarobmgEsqSlWwJJUiF9JJEkayApYUlPsAUtSIfaAJamQmipge8CSVIgVsKSm2IKQpEJqWoZmAEtqylTfjnJzMoAlNaWmCthJOEkqxApYUlNsQUhSITW1IAxgSU2xApakQmqqgJ2Ek6RCrIAlNcUWhCQVUlMLwgCW1JTMXukhTJg9YEkqxApYUlO8G5okFVLTDdkNYElNsQKWpEJqqoCdhJOkQqyAJTXFCzEkqRAvxJCkQmrqARvAkppS0yoIJ+EkqRArYElNsQUhSYW4CkKSCqmpArYHLEmFWAFLakpNqyAMYElNqakFYQBLaoqTcJJUSE2XIjsJJ0mFWAFLaootCEkqxEk4cebwKbz8ZQez9o47WbTXHwPwgRPfwTFHH8Edd94FwIknfpSvX/LNksPUNNv6yU9kz385nrk7zodectu/XcqtZ36N3d/zOnZY8kLoJQ/c+WtufNsneWDN3aWHW6WaesAx1X8tZs9dUM9vYzM68A9fxLp1v+Xss09/VACvW/dbTj3tM4VHV96l2+9feghFzN1pPnN33p511/2UWY+dx97LPsb1b/o499/2Kx5e9z8ALDj2pTz2mbvyo3efWXi00++gNV+MyZ5j7ta7TipzHrh/9bhjiIizgD8F1mbmc/v7Pgi8Gbijf9j7MvNrg97LSbgp8q0rlnPX3feUHoa2MA+svYd11/0UgId/ex+/u+lWtn7SE34fvgCzHrM1Ff0reiY6B1iygf2nZeai/jYwfMEAnnZvPe4oVlyzjDOHT2H+/MeXHo4KmrdwR7Z97h78ZsVNAOzxt4ez34pPsfOfH8gtJ3++8OjqlZmT2iZw/suBuzbHWDc5gCPiqM0xgJnk05/5LM/cc3/23ucQfvnLtXz85A+UHpIKmfWYeTxn6bv48Yln/776/elHzud7f3Acay74FguO3lCBpYnISW6TcHxE/CAizoqI7Sfygk3uAUfEzzNztzGeGwKG+g+HM3N4k96kfrsDFwOP9ImGRv0uHvWcZpQ5jPx//w3g1PU+FwBPAb6Kn40i1ssv2ECGRcTuwMWjesA7A3cykuF/D+ySmUcPeq9xV0FExA/GegrYeazX9Qc7U0N3TLvttttf83+/l1cB1xccjsoIYCmwCji1v28IuAy4qf/4FcCN0z80wablV2aueeTniDiTkT+wAw1ahrYzcCiw/nqYAL6zMQOcgc4HDgJ2AFYDJ51++um7Atcx8lfyFuAtpQanYg4AjmTkc7AS4LDDDnss8FGgA/SAnwF/VWqA2ngRsUtm3t5/OOHiatwWREQsBc7OzCs28Ny/Z+YRmzLYmSoirs7MfUqPQ1sWPxd1iYjRxdUa4KT+40WMKq5GBfLY56rpqpHabaDXJ/m5mMEMYEkqxHXAklSIATxNImJJRHQj4scR8d7S41F5/fWiayPC1TAzlAE8DSJiFvBJ4KXAs4HDI+LZZUelLcA5bPiSVs0QBvD02Bf4cWbenJkPAJ8DXll4TCpsc17SqjoZwNNjAfCLUY9X9/dJmsEM4OmxodvbufxEmuEM4OmxGlg46vGuwG2FxiJpC2EAT4+rgGdExB4RMRd4PXBR4TFJKswAngaZ+RBwPCN3v1oFfCEzbyg7KpXWv6T1u0AnIlZHxDGlx6Tp5ZVwklSIFbAkFWIAS1IhBrAkFWIAS1IhBrAkFWIAS1IhBrAkFWIAS1Ih/wsPkiJIFn7L3QAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x432 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "#optional confusion matrix vis template\n",
    "import seaborn as sn\n",
    "import pandas as pd\n",
    "\n",
    "df_cm = pd.DataFrame(confusion_matrix(y,yhat) )\n",
    "sn.heatmap(df_cm, annot=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# 14.2 Least Squares Classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/vb/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:10: FutureWarning: `rcond` parameter will change to the default of machine precision times ``max(M, N)`` where M and N are the input matrix dimensions.\n",
      "To use the future default and silence this warning we advise to pass `rcond=None`, to keep using the old, explicitly pass `rcond=-1`.\n",
      "  # Remove the CWD from sys.path while we load stuff.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.07333333333333333"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "0.07333333333333333"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#fHat can be made via a regression model, and the > comparator\n",
    "fTilde = lambda x: x@beta + v\n",
    "fHat = lambda x: fTilde(x) > 0\n",
    "\n",
    "\n",
    "iris = np.vstack([irisD[\"setosa\"],irisD[\"versicolor\"],irisD[\"virginica\"]])\n",
    "#if k == virginica: y[k] = True (1) ; else False (0)\n",
    "y = np.hstack([np.full(50, False),np.full(50, False),np.full(50, True)])\n",
    "A = np.hstack([np.ones((150,1)), iris])\n",
    "theta = npl.lstsq(A,2*y-1)[0]\n",
    "yhat = np.matmul(A,theta) > 0 #regression classifier \n",
    "\n",
    "C = confusion_matrix(y,yhat)\n",
    "(C[0,1] + C[1,0]) / len(y)\n",
    "np.average(y != yhat)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 14.3 Multi-class Classifiers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "K = 4\n",
    "y = np.random.randint(K,size=100)\n",
    "yhat = np.random.randint(K,size=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[5., 9., 6., 6.],\n",
       "       [9., 4., 9., 6.],\n",
       "       [6., 5., 3., 4.],\n",
       "       [7., 6., 6., 9.]])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "0.79"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "0.79"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "error_rate = lambda y,yhat: np.average(y != yhat)\n",
    "def confusion_matrix(y, yhat, K):\n",
    "    C = np.zeros((K,K))\n",
    "    for i in range(K):\n",
    "        for j in range(K):\n",
    "            C[i,j] = sum(np.logical_and(y == i, yhat == j))\n",
    "    return C\n",
    "C = confusion_matrix(y,yhat,K)\n",
    "C\n",
    "error_rate(y,yhat)\n",
    "1-np.sum(np.diag(C))/np.sum(C)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.axes._subplots.AxesSubplot at 0x1a1c2aa110>"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAVkAAAFqCAYAAACwHCnYAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAZeUlEQVR4nO3deXTdZZ3H8c83uWnadIHRUgqliwG5QAuEoRQq2FEqUBVFFkcY3BkLY2Vrz3FQYBAFhwMznYIelRSccQDrOCzKCFagWqqydWFfIpgptLalLbTEtnTNd/64gdZObhZunzxPn7xf59xjlie53/68fPLN9/n9fjF3FwAgjKrYBQBAzghZAAiIkAWAgAhZAAiIkAWAgAhZAAioELuAHrBY0p8lbZO0VdLYqNVEVCwWL5T0JUkmaWZTU9OMyCXFtKekmySNkeSSvijp4agVxcOxCKi3dLIflNSg3h2wY1QK2HGSDpd0crFYfG/cqqK6XtJsSQepdDyej1tOVByLbjKzC83sGTN71swu6mhtp52smR0k6RRJw1T6KbdM0t3uzv8Ru5eDJT3S1NS0QZKKxeKDkk6VdG3UquIYJGmCpM+3vb+57dEbcSy6ycx2bFg2S5ptZve4+4vtre+wkzWzf5T0E5V+vXxM0vy2t2eZ2SW7svCAXNJ9khZKmhy5lpiekTShWCy+u1gs1kn6iKThkWuKpV7SKkn/LulxlX5V7h+1ong4Ft13sKRH3H2Du2+V9FbD0i7r6LJaM/uDpNHuvmWnj/eR9Ky77w6/bu6rUvc9RNL9ks6XNC9qRZEUi8VzJE2RtE7Sc5LebGpqujhuVVGMlfSIpGMlParSr8stki6PWVQkHItuMrODJf1c0nhJb0qaI2mBu5/f7vpOQvYFSSe5+8s7fXykpPvcvVjm6yarrWs85V3jjjxqwAHv4J+y673/otO0ecNGPdp4b5Tn/+rCb0V53vbM+MF/aOiQwTrztJOjPP+TDVOjPK8kFfbaUwfffa2eHl/6xWbAuEM0dMppeulzV0WrKZbUjsXYpT+zSr/HltXNFd2Qpc9e+5+rv/ytt9HdG3dcY2b/r2Fx93Ybls42vi6SNMfMfmlmjW2P2Sol94XlvsjdG919rLuPjRmwNf1q1ad/37ffrp9wqFY1LY1WT2yvrVkrSVq+YqXmPPh7ffhDfxO5oji2rlqrzctWq7Z+X0nSoOMO08YXl0SuKo4sj0XrtooeO+ZX26Nx56dw95vd/a/dfYKk1yW1O4+VOtn4cvfZZnagSgPeYSrNY5dKmu/u2yo6ED2g/+BBOqOx9MOlqlCtZ3/+kJoffCpyVfFc/PWrtLalRYVCQZdO+7L2GDQwdknRvHL5TNV/Z6qsT0GbXn5Vi6fdELukaDgW3WdmQ9x9pZmNkHSaSqOD9teGvtXh1SPP5l6KbVIaF8QWc1yAdO2SccGrTRVlTs3exU5rMLPfSnq3pC2Sprr7nHJre8PFCAB6k9bW4E/h7u/v6lpCFkBW3MOHbHf0liu+ACAKOlkAeemBcUF3ELIA8pLYuICQBZCX1rTOLiVkAeQlsU6WjS8ACIhOFkBe2PgCgHBSO0+WkAWQFzpZAAgosU6WjS8ACIhOFkBeOE8WAAJKbFxAyALIS2IbX8xkASAgOlkAeWFcAAABJTYuIGQBZCW1v/FKyALIS2LjAja+ACAgOlkAeWEmCwABJTYuIGQB5IXLagEgoMQ6WTa+ACAgOlkAeWHjCwACSmxcQMgCyEtinSwzWQAIiE4WQF4S62QJWQBZ4QYxABASnSwABJTY2QVsfAFAQHSyAPLCuAAAAkpsXEDIAsgLnSwABJRYJ8vGFwAERCcLIC+MCwAgIEIWAAJiJgsAvQedLIC8MC4AgIASGxcQsgDy0gOdrJldLOnvJbmkpyV9wd03treWmSyAvHhrZY9OmNkwSRdIGuvuYyRVSzqz3HpCFgC6ryCpn5kVJNVJWtbRwqC+uvBboZ9it3FCw+TYJSTj/icaY5eQjCcbpsYuIS8VjgvMbLKkHf9jbXT3t1+w7v4nM/sXSa9IelPSfe5+X7nvx0wWQF4qDNm2QC3bBZjZX0k6RdJ7JK2V9N9m9ml3v7W99YwLAOTFvbJH5z4k6X/dfZW7b5F0p6T3lVtMJwsgL+HPLnhF0jFmVqfSuGCipAXlFtPJAkA3uPujkm6XtEil07eq1MF4gU4WQF564DxZd79C0hVdWUvIAsgLV3wBQECJ3buAmSwABEQnCyAvXTsNq8cQsgDykti4gJAFkBdCFgACSuzsAja+ACAgOlkAWfFWNr4AIBxmsgAQUGIzWUIWQF4SGxew8QUAAdHJAsgLM1kACIiQBYCAErt3ATNZAAiIThZAXhgXAEBAiZ3CRcgCyAsXIwBAQIl1smx8AUBAdLIAsuJsfAFAQImNCwhZAHlJbOOLmSwABEQnCyAvjAsAICA2vgAgIDpZAAiIjS8A6D3oZAHkhXEBAITDFV8AEBKdLAAElFjIsvEFAAHRyQLIS2KncBGyAPKS2LiAkAWQFU8sZJnJAkBAdLIA8pJYJ0vIAsgLFyMAQEB0sgAQECHbs2756c90x92z5e464+OT9JlPnRq7pKiqqqp0473f0+oVq/W1z18Wu5xoeF1sVz2ov0ZeN0X9iiMkdy2e9l2tX9QUu6xsZB2yLzYv1h13z9asm2aoplCj86ZdpgnvG6eRw4fFLi2a0885VS+/9Ir6D6iLXUo0vC7+0vArz1HL3EVqPvdaWU1BVf1qY5dUEfewnayZFSX91w4fqpf0T+4+o731WZ/C1bx4iQ4bfZD69e2rQqFaYxsO1Zx5D8UuK5q99hmsYyYerXt+fG/sUqLidbFd1YB+Gnj0aK2e9YAkybds1baW9ZGrqlCrV/bohLs3uXuDuzdIOlLSBkl3lVv/jkPWzL7wTr+2pxxQP1ILn3xGa99o0ZsbN+q3D8/XildXxS4rmq9848u68eqZwX/Sp47XxXa1I4Zq6+tvaNT0C3TI7Okaed2U3b6TDR2yO5ko6Y/u/nK5BZV0sleW+4SZTTazBWa24Kb/nFXBU1Rm/1Ej9MWzP6kvXfR1nTf1ch14QL2qq6uj1RPT+IlHa83qtfrD0y/GLiU6XhfbWaFKdWP216pbfqnnJk1V64aNGjrl9NhlVcRbvaLHjvnV9pjcwdOdKanDkOtwJmtmT5X7lKS9y/4j3RslNUrSltXNUdum0z92kk7/2EmSpBk/+A8NHTI4ZjnRjDlqjI49cbyOOX6c+tT2Ud3AOl16wyW6+oJrYpcWBa+Lks3LX9Pm5a9p/eOlH75r7nlYQ6ecFrmquHbMr46YWR9JH5f0tY7WdbbxtbekkySt2fn7S9othlivrVmrd//Vnlq+YqXmPPh73Xrj9NglRTHzmps185qbJUkN4w/Xp879ZK8NWInXxVu2rlqrzctWq7Z+X21qXqZBxx2mjS8uiV1WZXruFK4PS1rk7q92tKizkP2FpAHu/sTOnzCzue+8tp5z8dev0tqWFhUKBV067cvaY9DA2CUhAbwutnvl8pmq/85UWZ+CNr38qhZPuyF2SZXpuQu+zlInowJJstCbILHHBSk5oaGj0U7vcv8Tnf421ms82TA1dgnJGLv0Z1bp91h79vEVZc6et/260xrMrE7SEkn17v5GR2uzPk8WAEJw9w2S3t2VtYQsgLxwWS0ABJTWTbgIWQB5Se0vIxCyAPKSWCeb9b0LACA2OlkAWWFcAAAhJTYuIGQBZMUJWQAIKLGQZeMLAAKikwWQFcYFABASIQsA4aTWyTKTBYCA6GQBZCW1TpaQBZAVQhYAQvKK/7jCLkXIAshKap0sG18AEBCdLICseCvjAgAIJrVxASELICvOxhcAhJNaJ8vGFwAERCcLICtsfAFAQJ7Wn/giZAHkJbVOlpksAAREJwsgK6l1soQsgKwwkwWAgOhkASCg1K74YuMLAAKikwWQldQuqyVkAWSlNbFxASELICupzWQJWQBZSe3sAja+ACAgOlkAWeFiBAAIKLVxASELICupnV3ATBYAAqKTBZAVTuECgIDY+AKAgJjJAkBA7lbRoyvMbE8zu93MXjCz581sfLm1dLIA0H3XS5rt7meYWR9JdeUWErIAshJ6JmtmgyRNkPT50vP5Zkmby60PHrJPNkwN/RS7jROq945dQjL67fv+2CUk47ghB8cuIRlzd8H36IGZbL2kVZL+3cwOl7RQ0oXuvr69xcxkAWSl0pmsmU02swU7PCbv9BQFSX8t6fvufoSk9ZIuKVcP4wIAWam0k3X3RkmNHSxZKmmpuz/a9v7t6iBk6WQBoBvcfYWkJWZWbPvQREnPlVtPJwsgKz10LcL5km5rO7OgWdIXyi0kZAFkpScuRnD3JySN7cpaQhZAVlK7dwEzWQAIiE4WQFYS+4vghCyAvLjSGhcQsgCy0sqtDgEgnNbEOlk2vgAgIDpZAFlhJgsAAXF2AQAElFony0wWAAKikwWQFcYFABAQIQsAAaU2kyVkAWSlNa2MZeMLAEKikwWQldQuqyVkAWQlsfvDELIA8sLZBQAQUKulNS5g4wsAAqKTBZAVZrIAEBAzWQAIiIsRAKAXoZMFkBUuRgCAgNj4AoCAUpvJErIAspLa2QVsfAFAQHSyALLCTBYAAmImCwABpTaTJWQBZCW1kGXjCwACopMFkBVnJgsA4aQ2LiBkAWQltZBlJgsAAdHJAsgKFyMAQEBcjAAAAaU2kyVkAWQltZBl4wsAAqKTBZAVNr56WPWg/hp53RT1K46Q3LV42ne1flFT7LKimPK7Gdq8fqN8W6tat23TDz92eeySoqitrdXcX9+hPrW1KhSqdeed9+jKb/5r7LKiqqqq0o33fk+rV6zW1z5/WexyKsLGVw8bfuU5apm7SM3nXiurKaiqX23skqK69cyr9OaadbHLiGrTpk360Il/q/XrN6hQKGje3Ls0e/Zv9Ohji2KXFs3p55yql196Rf0H1MUupWI9MZM1s8WS/ixpm6St7j623NpOZ7JmdpCZTTSzATt9fFKlhYZWNaCfBh49WqtnPSBJ8i1bta1lfeSqkIL16zdIkmpqCirU1Mg9tV8ye85e+wzWMROP1j0/vjd2KbuEV/johg+6e0NHASt1ErJmdoGkn0s6X9IzZnbKDp/+dvfq6Xm1I4Zq6+tvaNT0C3TI7Okaed2UXt7Juv7u1kv0xV9cpSPO+mDsYqKqqqrSgvn3afmfntKcOfP02PzHY5cUzVe+8WXdePXMXv2DJqTOOtkvSTrS3T8h6QOSLjezC9s+V3byYWaTzWyBmS24c/3iXVLoO2GFKtWN2V+rbvmlnps0Va0bNmrolNOj1RPbj067Ujd/9DL95HPX6sjPnqDh4w6KXVI0ra2tGnvUiRr5nrE6auwRGj26GLukKMZPPFprVq/VH55+MXYpu0yrvKLHjvnV9pjcztO4pPvMbGGZz7+ts5lstbuvkyR3X2xmH5B0u5mNVAch6+6NkholacF+n4j243Hz8te0eflrWv946QW05p6HNXTKabHKiW7dyrWSpA2vtajpVwu0b0O9ljz2QuSq4nrjjRY9OO8hnXTiB/Tss71vQ3TMUWN07Injdczx49Snto/qBtbp0hsu0dUXXBO7tHes0pnsjvnVgWPdfZmZDZF0v5m94O7z2lvYWSe7wswadnjydZJOljRY0qHdqDuKravWavOy1aqt31eSNOi4w7TxxSWRq4qjpl+t+vTv+/bb9RMO1aqmpZGrimPw4Hdpjz0GSZL69u2rice/X01Nf4xcVRwzr7lZnzzqLJ05/tP65pSr9fjvn9itA1bqmZmsuy9r+9+Vku6SNK7c2s462c9K2rrTN98q6bNmdmMX64nqlctnqv47U2V9Ctr08qtaPO2G2CVF0X/wIJ3ReLEkqapQrWd//pCaH3wqclVx7LPP3vrhzTNUXV2lqqoq3X77/+ieex+IXRZ2E2bWX1KVu/+57e0TJX2z7PrQw+6Y44LU/Kq6f+wSknHF8rmxS0jGcUMOjl1CMuYufaDis1y/MfLsijLnGy/f1mENZlavUvcqlRrVH7v71eXWZ3+eLIDeJfTFCO7eLOnwrq4nZAFkpTWxC2sJWQBZSStiuQsXAARFJwsgK6ndT5aQBZAVZrIAEFBaEUvIAshMauMCNr4AICA6WQBZYSYLAAGlFbGELIDMMJMFgF6EThZAVjyxgQEhCyArqY0LCFkAWeHsAgAIKK2IZeMLAIKikwWQFcYFABAQG18AEBCncAFAQKl1smx8AUBAdLIAssK4AAACSm1cQMgCyEqrp9XJMpMFgIDoZAFkJa0+lpAFkBmu+AKAgDi7AAACSu3sAja+ACAgOlkAWWEmCwABMZMFgIBSm8kSsgCy4lzxBQC9B50sgKyw8QUAATGTBYCAUju7gJksAAREJwsgK8xkASCg1E7hImQBZIWNLwAIiI0vAOhF6GQBZKWnNr7MrFrSAkl/cveTy60jZAFkpQc3vi6U9LykQR0tYlwAICut8ooeXWFm+0n6qKSbOlsbvJN97wnrQj/F7uP+2AWk46QhR8UuIRmHPzE9dgnYgZlNljR5hw81unvjTstmSPqqpIGdfT/GBQCyUunZBW2BunOovs3MTpa00t0XmtkHOvt+hCyArLSGn8keK+njZvYRSX0lDTKzW9390+0tZiYLICte4aPT7+/+NXffz91HSTpT0q/LBaxEJwsgM9y7AAAy4e5zJc3taA0hCyArdLIAEBB34QKAgOhkASAg7sIFAL0InSyArDCTBYCAmMkCQECpdbLMZAEgIDpZAFlhXAAAAaV2ChchCyArPXCrw24hZAFkJbVOlo0vAAiIThZAVhgXAEBAqY0LCFkAWaGTBYCAUutk2fgCgIDoZAFkhXEBAASU2riAkAWQFffW2CX8BWayABAQnSyArHAXLgAIKLWbdhOyALJCJwsAAaXWybLxBQAB0ckCyAoXIwBAQFyMAAABpTaTJWQBZCW1swvY+AKAgOhkAWSFcQEABMTZBQAQUGqdLDNZAAiIThZAVlI7u4CQBZCV1MYFhCyArLDxBQABpXZZLRtfABAQnSyArDAuAICA2PjqQVVD91PdP1y2/f299tHGu36kzfffGbGqeKoH9dfI66aoX3GE5K7F076r9YuaYpcVBcdiu1t++jPdcfdsubvO+PgkfeZTp8YuqSKpzWSzDtnWFUu17orzSu9YlQb+20+0ZdHv4hYV0fArz1HL3EVqPvdaWU1BVf1qY5cUDcei5MXmxbrj7tmaddMM1RRqdN60yzThfeM0cviw2KW9Y6E7WTPrK2mepFqVMvR2d7+i3Ppes/FVOOQIta5cJn9tZexSoqga0E8Djx6t1bMekCT5lq3a1rI+clVxcCy2a168RIeNPkj9+vZVoVCtsQ2Has68h2KXlbpNko5398MlNUiaZGbHlFvcacia2TgzO6rt7UPMbKqZfWSXldtDao7+oLY8+pvYZURTO2Kotr7+hkZNv0CHzJ6ukddN6bXdG8diuwPqR2rhk89o7RstenPjRv324fla8eqq2GVVxN0renTh+7u7r2t7t6btUfYLOwxZM7tC0g2Svm9m/yzpu5IGSLrEzC7t4r85vuqCCg3jtWX+g7EricYKVaobs79W3fJLPTdpqlo3bNTQKafHLisKjsV2+48aoS+e/Ul96aKv67ypl+vAA+pVXV0du6yKeIWPrjCzajN7QtJKSfe7+6Nl13aU3Gb2tErtcK2kFZL2c/cWM+sn6VF3P6zM102WNLnt3UZ3b+xi7aGc0tzc/O36+vrRkeuIaaikRySNMrPJ7v68pEskfTRuWVFwLNphZpMPPPDAUZKWNjU1fS92PbHslF9SBxlmZntKukvS+e7+THtrOhsXbHX3be6+QdIf3b1Fktz9TUmt5b7I3RvdfWzbI3bAStJZ119/fV3sIiJbIWmJpKJKL6CJkp6LWlE8HIsdFIvFIZJUU1PzFUmnSZoVt6K4dsqvDjPM3ddKmitpUrk1nYXsZjN7K5yOfOuDZraHOgjZxNRJOuHWW29dG7uQBJwv6bampqZDVPoN5duR64mJY7HdHcVi8blhw4YdIGlKU1PTmtgFpczM9mrrYNX2W/2HJL1Qdn0n44Jad9/UzscHS9rH3Z+uvOSeYWYL3H1s7DpSwLHYjmOxHceia8zsMEk/klStUqP6U3f/Zrn1HZ4n217Atn18taTVFdQZQwpji1RwLLbjWGzHsegCd39K0hFdXd9hJwsAqEyvuRgBAGLIPmTNbJKZNZnZS2Z2Sex6YjKzH5rZSjNr91ST3sLMhpvZb8zseTN71swujF1TLGbW18weM7Mn247FlbFryk3W4wIzq5b0B0knSFoqab6ks9y9V56uY2YTJK2T9J/uPiZ2PbGY2T4qbdwuMrOBkhZK+kRvfF2YmUnq7+7rzKxG0u8kXejuj0QuLRu5d7LjJL3k7s3uvlnSTySdErmmaNx9nqTXY9cRm7svd/dFbW//WdLzknbfO6JUoLuXiKL7cg/ZYSqddP6Wpeql/zGhfWY2SqWd4rKXReauO5eIovtyD1lr52P8lIYkycwGSLpD0kVvXc3YG7Vd1dkgaT9J48ys146SQsg9ZJdKGr7D+/tJWhapFiSkbf54h6Tb3L133sV9J125RBTdl3vIzpf0XjN7j5n1kXSmpLsj14TI2jZ7bpb0vLtPj11PTN29RBTdl3XIuvtWSV+R9CuVNjd+6u7Pxq0qHjObJelhSUUzW2pm58SuKZJjJX1G0vFm9kTbY7e7R/Iuso+k35jZUyo1Jfe7+y8i15SVrE/hAoDYsu5kASA2QhYAAiJkASAgQhYAAiJkASAgQhYAAiJkASAgQhYAAvo/wUvf7M8wVKYAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x432 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "df_cm = pd.DataFrame(confusion_matrix(y,yhat,K) )\n",
    "sn.heatmap(df_cm, annot=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "#least squares multi-class classifier\n",
    "A = np.random.randn(4,5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[2, 4, 2, 2]"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "array([2, 4, 2, 2])"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "[1, 1, 0, 0, 1]"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "array([1, 1, 0, 0, 1])"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "row_argmax = lambda u:[np.argmax(u[i,:]) for i in range(np.size(u,0))]\n",
    "col_argmax = lambda u:[np.argmax(u[:,i]) for i in range(np.size(u,1))]\n",
    "row_argmax(A)\n",
    "np.argmax(A,1)\n",
    "col_argmax(A)\n",
    "np.argmax(A,0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "#can use the following to find n-vector predictions of a data set with N examples \n",
    "#stored as nxN 'x', and theta is an nxK \n",
    "fhat = lambda x,theta: row_argmax(x@theta)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "K = 4\n",
    "ycl = np.random.randint(K,size=6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "#matrix least squares\n",
    "def one_hot(ycl,K):\n",
    "    N = len(ycl)\n",
    "    Y = np.zeros((N,K))\n",
    "    for j in range(K):\n",
    "        Y[np.where(ycl==j),j] = 1\n",
    "    return Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ls_multiclass(X,ycl,K):\n",
    "    n,N = np.size(X),np.size(X) #debugging code left in on julia companion\n",
    "    theta = npl.lstsq(X.T, (2*one_hot(ycl,K)-1))[0]\n",
    "    yhat = np.argmax(np.matmul(X.T,theta),1)\n",
    "    return theta,yhat\n",
    "\n",
    "setosa = irisD[\"setosa\"]\n",
    "versicolor = irisD[\"versicolor\"]\n",
    "virginica = irisD[\"virginica\"]\n",
    "I1 = np.random.permutation(50)\n",
    "I2 = np.random.permutation(50)\n",
    "I3 = np.random.permutation(50)\n",
    "\n",
    "#4x120 data matrix: 4 features per flower, 120 flowers\n",
    "#transposed shape shows 4 arrays, 1 per feature, with 120 data pts in each\n",
    "Xtrain = np.vstack([setosa[I1[:40],:],\n",
    "          versicolor[I2[:40],:],\n",
    "           virginica[I3[:40],:]]).T \n",
    "#add constant feature array 1, now shape 5x120\n",
    "Xtrain = np.vstack([np.ones((1,120)),Xtrain]) \n",
    "Ytrain = np.vstack([np.full(40,1),np.full(40,2),np.full(40,3)]).ravel()\n",
    "\n",
    "Xtest = np.vstack([setosa[I1[40:],:],\n",
    "          versicolor[I2[40:],:],\n",
    "           virginica[I3[40:],:]]).T \n",
    "Xtest = np.vstack([np.ones((1,30)),Xtest]) \n",
    "Ytest = np.vstack([np.full(10,1),np.full(10,2),np.full(10,3)]).ravel()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/vb/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:3: FutureWarning: `rcond` parameter will change to the default of machine precision times ``max(M, N)`` where M and N are the input matrix dimensions.\n",
      "To use the future default and silence this warning we advise to pass `rcond=None`, to keep using the old, explicitly pass `rcond=-1`.\n",
      "  This is separate from the ipykernel package so we can avoid doing imports until\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(array([[-1.00000000e+00, -8.26661810e-01,  2.13489006e+00],\n",
       "        [ 2.66778329e-16,  1.35626321e-01, -1.35118329e-02],\n",
       "        [-6.48120074e-17,  5.02854430e-01, -9.13082470e-01],\n",
       "        [-5.64764927e-17, -4.53575801e-01,  3.82993528e-01],\n",
       "        [-3.96551917e-18, -1.04589530e-01, -8.74847378e-01]]),\n",
       " array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2,\n",
       "        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
       "        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2,\n",
       "        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0,\n",
       "        2, 2, 2, 2, 2, 2, 2, 2, 2, 2]))"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "array([[ 0.,  0.,  0.],\n",
       "       [ 0., 40.,  0.],\n",
       "       [ 0.,  1., 39.]])"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "0.3416666666666667"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "array([[0., 0., 0.],\n",
       "       [0., 0., 0.],\n",
       "       [0., 0., 0.]])"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "0.3333333333333333"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "theta,yhat = ls_multiclass(Xtrain,Ytrain,3)\n",
    "theta,yhat\n",
    "Ctrain = confusion_matrix(Ytrain,yhat,3)\n",
    "Ctrain\n",
    "error_train = error_rate(Ytrain,yhat)\n",
    "error_train\n",
    "yhat = row_argmax(np.matmul(Xtest.T,theta))\n",
    "Ctest = confusion_matrix(Ytest,yhat,3)\n",
    "error_test = error_rate(Ytest,yhat)\n",
    "Ctest\n",
    "error_test\n",
    "\n",
    "#performance is 2x worse than VLMS companion, need to revisit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(30,)"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.shape(yhat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
